import collections
import collections.abc
import warnings
from typing import Optional, Sequence, Union

import torch.cuda


# Names exported by `from ... import *`. Only the collective operations are
# listed; is_available, version, unique_id and init_rank are public but not
# included here.
__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']

# Reduction op code forwarded to the native binding as `op`; mirrors the
# ncclSum entry of NCCL's ncclRedOp_t enum. Used as the default `op` below.
SUM = 0  # ncclRedOp_t


def is_available(tensors):
    """Report whether NCCL collectives can be run over ``tensors``.

    Returns ``False`` (and emits a warning) when this PyTorch build has no
    NCCL bindings. Otherwise returns ``False`` as soon as a tensor is sparse,
    non-contiguous, not a CUDA tensor, or lives on a device already used by
    an earlier tensor in ``tensors``; returns ``True`` if none of those apply
    (including when ``tensors`` is empty).
    """
    if not hasattr(torch._C, '_nccl_all_reduce'):
        warnings.warn('PyTorch is not compiled with NCCL support')
        return False

    seen_devices = set()
    for t in tensors:
        # Short-circuit order matters: is_contiguous() is only queried for
        # non-sparse tensors, exactly as in a sequence of early returns.
        if t.is_sparse or not t.is_contiguous() or not t.is_cuda:
            return False
        dev = t.get_device()
        if dev in seen_devices:
            # At most one tensor per device is allowed.
            return False
        seen_devices.add(dev)
    return True


def version():
    """Return the NCCL version reported by the native binding.

    ``torch._C._nccl_version()`` packs the version into a single integer:
    the major number occupies the bits from 32 upward, the minor number
    bits 16-31 and the patch number bits 0-15. The result is unpacked into
    a ``(major, minor, patch)`` tuple.
    """
    packed = torch._C._nccl_version()
    low16 = 0xFFFF
    return (packed >> 32, (packed >> 16) & low16, packed & low16)


def unique_id():
    """Return the value produced by ``torch._C._nccl_unique_id()``.

    Presumably an NCCL unique id used to bootstrap communicators via
    ``init_rank`` — confirm against the C++ binding.
    """
    get_uid = torch._C._nccl_unique_id
    return get_uid()


def init_rank(num_ranks, uid, rank):
    """Forward to ``torch._C._nccl_init_rank(num_ranks, uid, rank)`` and return its result.

    NOTE(review): presumably creates the NCCL communicator for ``rank`` among
    ``num_ranks`` participants, with ``uid`` obtained from ``unique_id`` —
    confirm against the C++ binding.
    """
    init_comm = torch._C._nccl_init_rank
    return init_comm(num_ranks, uid, rank)


def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    """Raise ``TypeError`` unless ``inputs`` is a non-tensor container.

    The collectives in this module expect a collection of tensors (one per
    participant), never a single tensor. A bare ``torch.Tensor`` would itself
    satisfy the ``Container`` check, so it is rejected explicitly. The
    element types are not inspected here.

    Raises:
        TypeError: if ``inputs`` is not a container, or is a ``torch.Tensor``.
    """
    # `collections.Container` was a deprecated alias that was removed in
    # Python 3.10; the ABC lives in `collections.abc`.
    if not isinstance(inputs, collections.abc.Container) or isinstance(inputs, torch.Tensor):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    """Run an NCCL all-reduce over ``inputs``.

    When ``outputs`` is omitted, ``inputs`` is also passed as the output
    sequence (an in-place reduction). Both sequences are validated with
    ``_check_sequence_type``. Everything is then forwarded to
    ``torch._C._nccl_all_reduce`` together with ``op``, ``streams`` and
    ``comms``.
    """
    _check_sequence_type(inputs)
    targets = inputs if outputs is None else outputs
    _check_sequence_type(targets)
    torch._C._nccl_all_reduce(inputs, targets, op, streams, comms)


# `output` used to be `outputs`, which took a list of tensors. Both arguments
# are therefore accepted, for backward compatibility (BC).
def reduce(inputs: Sequence[torch.Tensor],
           output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
           root: int = 0,
           op: int = SUM,
           streams: Optional[Sequence[torch.cuda.Stream]] = None,
           comms=None, *,
           outputs: Optional[Sequence[torch.Tensor]] = None) -> None:
    """Reduce ``inputs`` with NCCL into a single output tensor.

    Args:
        inputs: collection of tensors to reduce (validated by
            ``_check_sequence_type``).
        output: destination tensor; defaults to ``inputs[root]``. A sequence
            of tensors is still accepted for backward compatibility
            (deprecated, emits a warning); its ``root``-th element is used.
        root: index selecting the default destination in ``inputs`` and the
            element taken from a legacy output list; also forwarded to NCCL.
        op: reduction op code forwarded to NCCL; defaults to ``SUM``.
        streams: forwarded unchanged to ``torch._C._nccl_reduce``.
        comms: forwarded unchanged to ``torch._C._nccl_reduce``.
        outputs: deprecated keyword-only list of output tensors; its
            ``root``-th element is used as the destination.

    Raises:
        TypeError: if ``inputs`` is not a collection of tensors.
        ValueError: if both ``output`` and ``outputs`` are given.
    """
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' can not be both specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None).")
        # stacklevel=2 attributes the deprecation to the caller, not this module.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor with argument 'output' instead.",
            stacklevel=2)
        _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(output, collections.abc.Sequence):
        # User called old API with positional arguments of list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.",
            stacklevel=2)
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)


def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None:
    """Validate ``inputs`` and dispatch it to ``torch._C._nccl_broadcast``.

    ``root``, ``streams`` and ``comms`` are forwarded unchanged.
    NOTE(review): presumably ``inputs[root]`` is the source that NCCL copies
    into the other tensors — confirm against the C++ binding.
    """
    _check_sequence_type(inputs)
    # Look up the binding only after validation so bad input still raises TypeError.
    native_broadcast = torch._C._nccl_broadcast
    native_broadcast(inputs, root, streams, comms)


def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None:
    """Validate both tensor collections, then run ``torch._C._nccl_all_gather``.

    ``inputs`` is checked before ``outputs``; ``streams`` and ``comms`` are
    forwarded unchanged to the native binding.
    """
    for tensors in (inputs, outputs):
        _check_sequence_type(tensors)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)


def reduce_scatter(inputs: Sequence[torch.Tensor],
                   outputs: Sequence[torch.Tensor],
                   op: int = SUM,
                   streams=None, comms=None) -> None:
    """Validate both tensor collections, then run ``torch._C._nccl_reduce_scatter``.

    ``inputs`` is checked before ``outputs``. The reduction ``op`` (default
    ``SUM``), ``streams`` and ``comms`` are forwarded unchanged.
    """
    for tensors in (inputs, outputs):
        _check_sequence_type(tensors)
    native_reduce_scatter = torch._C._nccl_reduce_scatter
    native_reduce_scatter(inputs, outputs, op, streams, comms)
